{- «•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»

    Copyright © 2011 - 2021, Ingo Wechsung
    All rights reserved.

    Redistribution and use in source and binary forms, with or
    without modification, are permitted provided that the following
    conditions are met:

        Redistributions of source code must retain the above copyright
        notice, this list of conditions and the following disclaimer.

        Redistributions in binary form must reproduce the above
        copyright notice, this list of conditions and the following
        disclaimer in the documentation and/or other materials provided
        with the distribution. Neither the name of the copyright holder
        nor the names of its contributors may be used to endorse or
        promote products derived from this software without specific
        prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE
    COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
    PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER
    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
    USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
    AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
    LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
    IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
    THE POSSIBILITY OF SUCH DAMAGE.

    «•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•» -}

{--
 * Type unification and utility functions for the type checker.
 -}


package frege.compiler.tc.Util where

import frege.Prelude hiding(<+>)

import Data.TreeMap (TreeMap, values, lookup, insert, keys, 
                        including, union, contains)
import Data.List as DL(unique, uniq, sort, elemBy, partition)

import frege.compiler.enums.Flags as Compilerflags(TRACET)

import  Compiler.types.Positions
import  Compiler.types.QNames
import  Compiler.types.Types
import  Compiler.types.Expression
import  Compiler.types.Symbols
import  Compiler.types.Global as G

import  Compiler.common.Errors as E()
import  Compiler.common.Binders
import  Compiler.common.Types as TT (betterReadable, substRho)

import  Compiler.classes.Nice
import  Compiler.instances.Nicer

import  Compiler.Kinds as KI()

import Lib.PP (msgdoc, text, </>, <+/>, <+>, <>, nest)
import frege.compiler.Utilities as U except (print, println)
import frege.compiler.Javatypes


data Expected t = Check t | Infer t

newSigmaTyVar d = ForAll [] <$>  newRhoTyVar d
newRhoTyVar   d = RhoTau [] <$>  newMeta2 d
newMeta       d = Meta      <$>  newFlexiTyVar d
newMeta2      d = Meta      <$>  newFlexiTyVar (TVar Position.null (snd d) (fst d))
newFlexiTyVar TVar{kind=k, var=n} = do u <- uniqid; stio (Flexi u n k)
newFlexiTyVar _ = error "no tyvar"
newRigidTyVar TVar{kind=k, var=n} = do u <- uniqid; stio (Rigid u n k)
newRigidTyVar _ = error "no tyvar"

instSigma ex sig erho = do
        g <- getST
        E.logmsg TRACET (getpos ex) (text ("instSigma: " ++ ex.nice g ++ "  ::  "
                                        ++ nice sig g))
        rho <- instantiate sig
        instRho ex rho erho

instExplain ex ty (Infer _) = do
    g <- getST
    E.explain (getpos ex) (msgdoc (is ex ++ "  " ++ ex.nice g ++ "  ::  " ++ ty.nicer g))
instExplain ex ty (Check s) = do
    g <- getST
    E.explain (getpos ex) (msgdoc (is ex ++ "  " ++ ex.nice g ++ "  ::  " ++ ty.nicer g
        ++ "  expected is  " ++ s.nicer g))


instRho :: Expr -> Rho -> Expected Rho -> StG Expr
instRho ex ty ety = do
        g <- getST
        E.logmsg TRACET (getpos ex) (text ("instRho initial: " ++ ex.nice g ++ "  ::  "
                                        ++ ty.nice g))
        -- ty <- contexts ex ty    -- make context canonical, merge others from ex and simplify
        -- E.logmsg TRACET (getpos ex) (text ("instRho contexts: " ++ ex.nice g
        --                                  ++ "  ::  " ++ ty.nice g))
        case ety of
            Check r   -> do subsCheckRR ex ty r
                            g ← getST   -- make unifications visible
                            let mty = ty.{context ← mergeCtx r.context . reducedCtxs g}
                            instExplain ex mty ety
                            E.logmsg TRACET (getpos ex) (text ("instRho subschecked: "
                                                ++ ex.nice g ++ "  ::  " ++ mty.nice g))
                            stio (ex.{typ=Just (ForAll [] mty)})
            Infer r -> do
                instExplain ex ty ety
                -- subsCheckRR ex ty r
                -- doio (ref.put (Just ty))
                pure ex.{typ=Just (ForAll [] ty.{context ← reducedCtxs g})}

instPatSigma pat sigma esig = do
        g <- getST
        E.logmsg TRACET (getpos pat) (text ("InstPatSigma: " ++ pat.nice g ++ " :: " ++ sigma.nice g))
        instExplain pat sigma esig
        case esig of
            Check s ->   subsCheck pat sigma s
            Infer r ->   subsCheck pat sigma r

-- subsCheck exp s1 s2: check that the offered sigma s1 is at least as polymorphic as the expected sigma s2
subsCheck exp s1 s2 = do
    g <- getST
    E.logmsg TRACET (exp.getpos) (text ("subsCheck: " ++ s1.nice g ++ " <= " ++ s2.nice g))
    (!skol_tvs, !skol) ← skolemise s2
    off ← instantiate s1
    subsCheckRR exp off skol
    g <- getST
    let tvs1 = sigmaTvs g s1
        tvs2 = sigmaTvs g s2
    let !esc_tvs = tvs1 ++ tvs2
        !bad_tvs = [ tv | tv <- skol_tvs, tv `elem` esc_tvs ]
    unless (null bad_tvs) do
        E.logmsg TRACET (getpos exp) (text ("skolTvs:  " ++ joined ", " (map (flip nice g) skol_tvs)))
        E.logmsg TRACET (getpos exp) (text ("sigm1Tvs: " ++ joined ", " (map (flip nice g) tvs1)))
        E.logmsg TRACET (getpos exp) (text ("sigm2Tvs: " ++ joined ", " (map (flip nice g) tvs2)))
        E.logmsg TRACET (getpos exp) (text ("bad_tvs:  " ++ joined ", " (map (flip nice g) bad_tvs)))
        polyerr exp s1 s2
    --g ← getST
    --let tycs    = reducedCtxs g off.context
    --    etycs   = reducedCtxs g skol.context
    --    implies = impliesG g
    --    badcs   = [ etycx | etycx ← etycs, not (any (`implies` etycx) tycs)]
    --unless (null badcs) do
    --    E.error (getpos exp) (text "Instantiation error in " <+> text (nicer exp g)
    --        </> (text "of type " <+> text (nicer off g))
    --        </> (text "cannot be instantiated at " <+> text (nicer skol g))
    --        </> (text "because expected constraints are missing:")
    --        </> (text (nicerctx badcs g))
    --        </> (text "(Try to use an annotated local binding instead of the plain expression.)")  
    --        )

  where
    polyerr !exp !s1 !s2 = do
        g <- getST
        let !pos = getpos exp
        E.error pos (text "Type" <+/> text (s1.nicer g)
            </> (text "inferred from " <+/> text (exp.nice g) <+/> text "is not as polymorphic as")
            </> (text "expected type " <+/> text (s2.nicer g)))

subsCheckSR exp sig rho = do
    g <- getST
    E.logmsg TRACET (getpos exp) (text ("subsCheckSR: " ++ sig.nice g ++ " :> " ++ rho.nice g))
    off <- instantiate sig
    subsCheckRR exp off rho

-- check constraints (used only from Classes.fr)
-- offered type must not be less constrained than expected
-- which means, all contexts from the expected type must be implied by the offered ones
checkConstraints exp s1 s2 = do
        (_,skol) <- skolemise s2
        offered   <- instantiate s1
        subsCheckRR exp offered skol         -- unify types
        g   <- getST
        let off = canonicContext g offered
            ety = canonicContext g skol

        let pos = getpos exp
        E.logmsg TRACET pos (text ("expr context:     "   ++ nicectx off.context g))
        E.logmsg TRACET pos (text ("expected context: "   ++ nicectx ety.context g))
        -- find expected constraints that are not implied by any offered constraint
        let implies = impliesG g
            bad = [ ctx |   ctx <- ety.context,
                            not (any (`implies` ctx) off.context) ]

        unless (null bad) do
            g <- getST
            E.error (getpos exp) (msgdoc "offered type is less constrained than expected type"
                    </> (text "offered:  " <+> text (nicer off g))
                    </> (text "expected: " <+> text (nicer ety g)))
        stio ()

-- subsCheckRR :: (Positioned e, Nice e) => e -> Rho -> Rho -> StG ()
subsCheckRR ex ty ety = do
        g <- getST
        E.logmsg TRACET (getpos ex) (text ("subsCheckRR: " ++ ty.nice g ++ " <= " ++ ety.nice g))
        subsCheckRR' ex ty ety
    where
            {-
             - implement rule FUN if one of the types is a RhoFun
             -}
            subsCheckRR' exp t1 (RhoFun _ a2 r2) = do
                        (a1,r1) <- unifyFun exp t1
                        subsCheckFun exp a1 r1 a2 r2
            subsCheckRR' exp (RhoFun _ a1 r1) t2 = do
                        (a2,r2) <- unifyFun exp t2
                        subsCheckFun exp a1 r1 a2 r2
            {-
             | otherwise revert to ordinary unification
             -}
            subsCheckRR' expr (RhoTau _ off) exp = unify expr off exp.tau

subsCheckFun exp s1 r1 s2 r2 = do
            subsCheck   exp s2 s1
            subsCheckRR exp r1 r2
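-- Illustrative trace of the contravariant argument check in 'subsCheckFun'
-- (the types below are hypothetical examples, not taken from this module):
--      offered:    (Int -> Int) -> Bool
--      expected:   (forall a. a -> a) -> Bool
--   1. subsCheck exp s2 s1 checks  (forall a. a -> a) <= (Int -> Int),
--      which succeeds by instantiating a with Int.
--   2. subsCheckRR exp r1 r2 checks  Bool <= Bool.
-- With offered and expected swapped, step 1 would have to check
--      (Int -> Int) <= (forall a. a -> a)
-- which skolemises a to a rigid type variable that cannot unify with Int.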

unifyFun exp (RhoFun cx sigma res) = pure (sigma, res)
unifyFun exp (RhoTau cx tau) =  do
        g <- getST
        arg_ty <- newMeta2 ("arg", KType)
        res_ty <- newMeta2 ("res", KType)
        let !funty = Tau.tfun arg_ty res_ty
        b <- unified exp tau funty
        unless b do
            g <- getST
            E.error (getpos exp) (message g funty)
            E.hint  (getpos exp) (text ("too many or too few arguments perhaps?"))
        pure (ForAll [] (RhoTau [] arg_ty), RhoTau [] res_ty)
    where
        message g funty = part1 </> part2 </> part3
            where
                part1 = text "type error in" <+> text exp.is </> nest 2 (nicest g exp)
                part2 = text "type is apparently " <+> text (tau.nicer g)
                part3 = text "does not match function type " <+> text (better g funty) 



--- return the unbound sigma type variables
sigmaTvs :: Global -> Sigma -> [MetaTv]
sigmaTvs g  = keys . getSigmaTvs g

rhoTvs :: Global -> Rho -> [MetaTv]
rhoTvs g = keys . getRhoTvs g

tauTvs g = keys . getTauTvs g

ctxTvs :: Global -> Context -> [MetaTv]
ctxTvs g ctx = tauTvs g ctx.tau

getSigmaTvs g (ForAll _ rho) = getRhoTvs g rho
getRhoTvs g (RhoFun cs sig rho) = let
        csTvs = map (getCtxTvs g) cs
        sTvs  = getSigmaTvs g sig
        rTvs  = getRhoTvs g rho
    in (fold union (sTvs `union` rTvs) csTvs)
getRhoTvs g (RhoTau cs tau) = let
        csTvs = map (getCtxTvs g) cs
        tTvs  = getTauTvs g tau
    in (fold union tTvs csTvs)

getCtxTvs g = getTauTvs g . Context.tau

getTauTvs g tau = getTauTvsT g TreeMap.empty tau

getTauTvsT g t (TApp a b) = 
    let ta =  getTauTvsT g t a
    in getTauTvsT g ta b
getTauTvsT g t (TCon {pos}) = t
getTauTvsT g t (TVar {pos}) = t
getTauTvsT g t (Meta tv) = 
     case Global.bound g tv of
        Just ty -> getTauTvsT g t ty
        _ -> (t `including` tv)
getTauTvsT g t (TSig s) = fold including t (sigmaTvs g s)
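-- Illustrative example (t1, t2 are hypothetical meta variables):
--      substitution:   t2 := Int,  t1 unbound
--      getTauTvs g (t1 -> [t2])    collects just  t1
-- Bound metas are followed to their binding, so t2 contributes only the
-- unbound variables of Int, which has none.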

{--
 * get the type variables that are mentioned in the current environment
 * except for symbol @sid@
 -}
envTvs g sid = [ m |
            q   <- g.typEnv, 
            sym <- g.findit q,
            sym <- (g.follow  sym),    -- follow aliases
            sym.{expr?},
            sym.sid != sid,
            m   <- sigmaTvs g sym.typ ]

--- read a type var monadically
readTv :: MetaTv -> StG (Maybe Tau)
readTv tv = do
    g <- getST
    return (g.bound tv)
writeTv :: MetaTv -> Tau -> StG ()
writeTv (Flexi{uid}) tau = changeST _.{tySubst <- TreeMap.insertkvI uid tau}
writeTv tv _ = do
    g <- getST
    E.fatal Position.null (text ("write to rigid tyvar " ++ tv.nice g))

skolemise :: Sigma -> StG ([MetaTv], Rho)
skolemise (ForAll [] ty) = (,) [] <$> instWildRho ty
skolemise (ForAll ns ty) = do
    tvs <- mapSt newRigidTyVar ns
    let tree = TreeMap.fromList (zip (map _.var ns) (map Meta tvs))
    rho  ←  instWildRho (substRho tree ty)
    pure (tvs, rho)

--- replace all 'TVar' with flexible 'Meta' type variables.
--- Wild card constructs are taken care of (through 'instWildRho')
instantiate :: Sigma -> StG Rho
instantiate (ForAll [] ty) = instWildRho ty
instantiate (ForAll ns ty) = do
    tvs <- mapSt newMeta ns
    let tree = TreeMap.fromList (zip (map _.var ns) tvs)
    instWildRho (substRho tree ty)
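-- Illustrative comparison of 'instantiate' and 'skolemise' on the
-- hypothetical type  forall a b. a -> b -> a :
--      instantiate  yields  t1 -> t2 -> t1   with fresh Flexi variables t1, t2,
--                   which unification may bind later.
--      skolemise    yields  a' -> b' -> a'   with fresh Rigid variables a', b',
--                   and also returns [a', b'] so that callers like 'subsCheck'
--                   can check that they do not escape.
-- Binding a Rigid variable is refused by 'writeTv'.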

--- Substitute all wild cards with flexible Meta types
--- This cannot be done with 'substRho' and friends, since wild cards are unnamed.
--- Every wild card gets its own meta type var.
instWildRho RhoFun{context, sigma, rho} = do
    c  ← mapM instWildCtx context
    s  ← instWildSigma sigma
    r  ← instWildRho rho
    pure RhoFun{context=c, sigma=s, rho=r}
instWildRho RhoTau{context, tau} = do
    c  ← mapM instWildCtx context
    t  ← instWildTau tau
    pure RhoTau{context=c, tau=t}

--- see 'instWildRho'
--- don't descend into higher rank types yet (or should we?) 
instWildSigma (sigma@ForAll{bound, rho})
    | null bound = ForAll bound <$> instWildRho rho
    | otherwise  = pure sigma

instWildCtx (ctx@Ctx{pos, cname, tau}) = ctx.{tau=} <$> instWildTau tau

instWildTau (tv@TVar{pos, kind, var})  = case tv.wildTau of
    Just _ | KGen tau ← kind = mapM instWildTau tau >>= newMeta . tv.{kind=} . KGen
    other                    = pure tv  -- ordinary TVars must not be instantiated/skolemised here
instWildTau (TApp t1 t2) = do
    t1'  ← instWildTau t1
    t2'  ← instWildTau t2
    pure (TApp t1' t2')
instWildTau (tcon@TCon{pos, name}) = pure tcon
instWildTau (tsig@TSig s)          = TSig <$> instWildSigma s
instWildTau (meta@Meta m)          = case m.kind of
    KGen tau = mapM instWildTau tau >>= pure . Meta . m.{kind=} . KGen
    other    = pure meta
{--
 * like instantiate, but give the tvs back
 -}
instantiateTvs (ForAll [] ty) = stio ([], ty)
instantiateTvs (ForAll ns ty) = do
    tvs <- mapSt newMeta ns
    let tree = TreeMap.fromList (zip (map _.var ns) tvs)
        rho  = substRho tree ty
    stio (tvs, rho)


unify :: (Positioned a, Nice a) => a -> Tau -> Tau -> StG ()
unify ex t1 t2 = do
        r <- unified ex t1 t2
        unless r do
            g <- getST
            let pos = getpos ex
            E.error pos (part1 g </> part2 g </> part3 g)
    where
        -- better2 = betterReadable t2
        part1 g = text "type error in" <+> text ex.is </> nest 2  (nicest g ex)
        part2 g = text "type is :" <+> text (better g t1)
        part3 g = text "expected:" <+> text (better g t2)

better g  = flip nicer g . betterReadable g

--- resolve 'Meta' type, if possible
--- If this returns a 'Meta' then it is unbound. 
reduced :: Tau -> Global -> Tau
reduced tau g 
    | Meta tv <- tau = 
        case g.bound tv of
            Nothing -> tau
            Just ty -> reduced ty g
    | otherwise  = tau
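-- Illustrative example for 'reduced' (t1, t2, t3 are hypothetical meta
-- variables, not identifiers from this module):
--      substitution:   t1 := t2,  t2 := Int,  t3 unbound
--      reduced (Meta t1) g     is  Int        (follows t1 -> t2 -> Int)
--      reduced (Meta t3) g     is  Meta t3    (unbound, returned unchanged)
--      reduced [Meta t1] g     is returned unchanged, since only a top level
--                              'Meta' is resolved; 'reducedTau' also resolves
--                              the nested ones.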


unified :: (Positioned a, Nice a) => a -> Tau -> Tau -> StG Bool
unified ex tau1 tau2 = do
    g <- getST
    E.logmsg TRACET (getpos ex) (text ("unify:  " ++ nice tau1 g ++ "  and  " ++ nice tau2 g))

    let t1 = reduced tau1 g
        t2 = reduced tau2 g

    when (badType t1 || badType t2) do
        E.fatal (getpos ex) (text "bad types in unification, turn on -xt for details")

    case (t1, t2) of
        (Meta tv, ty)        | Meta tv2 <- ty, tv == tv2 = return true
                             | tv.isFlexi = unifyVar ex tv (Right ty)
        (ty, Meta tv)        | tv.isFlexi = unifyVar ex tv (Left ty)
        -- (TFun a b, TFun c d)              = liftM2 (&&) (unified ex a c) (unified ex b d)
        (TCon{}, TCon{})              = if t1.name == t2.name 
                                                then return true 
                                                else unifyTCon (getpos ex) t1.name t2.name
        (TApp a b, TApp c d)              = do
                                                left <- unified ex a c
                                                if left then unified ex b d
                                                else return false
        (Meta (Rigid _ s _), ty)          = do
                                            E.error (getpos ex) (msgdoc ("type  `" ++ ty.nicer g
                                                ++ "` is not as polymorphic as suggested "
                                                ++ " in the annotation where just  `" ++ s
                                                ++ "`  is announced."))
                                            E.hint  (getpos ex) (msgdoc "The inferred type must be at least as polymorphic as the annotated one.")
                                            stio false
        (ty, Meta (Rigid _ s _))          = do
                                            E.error (getpos ex) (msgdoc ("type  `" ++ ty.nicer g
                                                ++ "` is not as polymorphic as suggested "
                                                ++ " in the annotation where just  `" ++ s
                                                ++ "`  is announced."))
                                            E.hint  (getpos ex) (msgdoc "The inferred type must be at least as polymorphic as the annotated one.")
                                            stio false
        _                                 = stio false
  where
    badType :: Tau -> Bool
    badType (TVar {pos}) = true
    badType _            = false
    -- unifyTCon will only be called with real type names (no aliases)
    -- It returns true if both 'TCon's describe native types and the 
    -- first one is a subtype of the other *and* if their mutability matches
    unifyTCon :: Position -> QName -> QName -> StG Bool
    unifyTCon pos t1 t2 = do
        g ← getST
        E.logmsg TRACET (getpos ex) (text ("unifyj:  " ++ nice t1 g ++ "  and  " ++ nice t2 g))
        sym1 <- U.findT t1
        case sym1.nativ of
            Just c1 -> do
                sym2 <- U.findT t2
                case sym2.nativ of
                    Just c2
                        -- Don't unify A and B when either one is based on a primitive type
                        -- not even (and foremost) if it's the same one!  
                      | c1 `elem` primitiveTypes = pure false
                      | c2 `elem` primitiveTypes = pure false 
                      | otherwise = do
                        E.logmsg TRACET (getpos ex) (text "unifyj:  " <> text c1 <> text "  and  " <> text c2)
                        let result = subTypeOf g c1 c2
                        when (not result) do
                            E.hint pos (text ("supertypes of " ++ c1
                                ++ ": " ++ joined ", " (U.supersOfNativ c1 g)))
                            E.hint pos (text ("does not contain " ++ c2))
                        return result 
                    nothing -> return false
            nothing -> return false


unifyVar :: (Positioned a, Nice a) => a -> MetaTv -> Either Tau Tau -> StG Bool
unifyVar ex tv lrtau = do
    bound <- readTv tv
    case bound of
        Just ty -> case lrtau of
            Left tau  -> unified ex tau ty
            Right tau -> unified ex ty tau
        Nothing -> either unbound unbound lrtau
  where
    unbound (t2@Meta tv2) 
        -- TODO: work out unification of proper union types
        | (tau1:_) ← (Meta tv).bounds, (tau2:_) ← t2.bounds = do
                u ← unified ex tau1 tau2
                when (u) do
                    writeTv tv t2
                    g <- getST
                    E.logmsg TRACET (getpos ex) (text ("unifyVar: " 
                            ++ show tv.uid ++ " " 
                            ++ tv.nice g
                            ++ " :: " ++ show tv.kind))
                pure true  
        | Nothing ← KI.unifyKind tv.kind tv2.kind = do
            g ← getST
            E.error (getpos ex) (text "Kind error in unification of "
                </> (text (tv.nicer g) <+> text "::" <+> text (tv.kind.nicer g) <+> text " with")
                </> (text (tv2.nicer g) <+> text "::" <+> text (tv2.kind.nicer g)))
            pure false
    unbound tau = do            -- unifyUnboundVar
        g <- getST
        let tauTvs = getTauTvs g tau
            tvar   = TVar{pos=getpos tau, kind=KVar, var=" occurs in type "}    -- trick to make 'better' work
            tapp   = TApp (TApp (Meta tv) tvar) tau                             -- fake type for showing
        if tauTvs `contains` tv then do
                E.error (getpos ex) (
                    text (better g tapp)    -- "t1 occurs in type (t2->t1)"
                    </> (text "caused by" <+> text ex.is)
                    </> nicest g ex)
                stio false
            else case tv.kind of 
                -- KGen t -> unifyKinded t tau  
                other -> do
                    writeTv tv tau
                    g <- getST
                    E.logmsg TRACET (getpos ex) (text ("unifyVar: " 
                            ++ show tv.uid ++ " " 
                            ++ tv.nice g
                            ++ " :: " ++ show tv.kind))
                    stio true
    -- We have tv≤Foo and Bar
    -- Unification is ok when Bar is a subtype of Foo
    -- We need to expand the MetaTv one step 
    unifyKinded t tau = do
        st ← substMeta tv.uid (Meta tv) t
        unified ex st tau

--- substitute the 'MetaTv' with the given UID in a 'Tau'
substMeta ∷ Int → Tau → Tau → StG Tau
substMeta uid rep tau = case tau of
    TVar{kind=KGen ts} = mapM (substMeta uid rep) ts >>= pure . tau.{kind=} . KGen
    TApp a b           = liftM2 TApp (substMeta uid rep a) (substMeta uid rep b)
    TCon{}             = pure tau
    Meta tv | uid == tv.uid = pure rep
    Meta tv           = do
        bound ← readTv tv
        case bound of 
            Just ty -> substMeta uid rep ty      -- skip enclosing Metas
            Nothing -> case tv.kind of
                KGen ts ->  mapM (substMeta uid rep) ts >>= pure . Meta . tv.{kind=} . KGen
                other   ->  pure tau
    other              = pure tau




{--
    eliminate any substitutions in the type
    -}
zonkSigma :: Sigma -> StG Sigma
zonkSigma (ForAll ns ty) = do rho <- zonkRho ty; stio (ForAll ns rho)
cleanSigma (ForAll ns ty) = do rho <- cleanRho ty; stio (ForAll ns rho)

zonkRho   :: Rho   -> StG Rho
zonkRho (RhoFun ctxs arg res) = liftM3 RhoFun (mapSt zonkCtx ctxs) (zonkSigma arg) (zonkRho res)
zonkRho (RhoTau ctxs tau)     = liftM2 RhoTau (mapSt zonkCtx ctxs) (zonkTau tau)
cleanRho (RhoFun ctxs arg res) = liftM3 RhoFun (zonkCtxs ctxs) (zonkSigma arg) (zonkRho res)
cleanRho (RhoTau ctxs tau)     = liftM2 RhoTau (zonkCtxs ctxs) (zonkTau tau)

zonkCtxs  :: [Context] -> StG [Context]
zonkCtxs ctxs = do
    ctxs <- mapSt zonkCtx ctxs
    return (filter withVars ctxs)
     
zonkCtx ctx = do
    let tau = Context.tau ctx 
    tau <- zonkTau tau
    return ctx.{tau}

withVars = withTauVars . Context.tau
withTauVars (TCon {})  = false
-- withTauVars (TFun a b) = withTauVars a || withTauVars b
withTauVars (TApp a b) = withTauVars a || withTauVars b
withTauVars vars       = true

zonkTau   :: Tau   -> StG Tau
-- zonkTau (TFun arg res)   = liftM2 TFun (zonkTau arg) (zonkTau res)
zonkTau (TApp a b)       = liftM2 TApp (zonkTau a)   (zonkTau b)
zonkTau (m@Meta tv)      = do
        mbtau <- readTv tv
        case mbtau of
            Nothing | Just x ← m.wildTau = do
                                        k ← zonkKind tv.kind
                                        pure TVar{pos=getpos m, kind=k, var=x}
                    | otherwise = pure m
            Just ty -> do      -- short out multiple hops
                            ty <- zonkTau ty
                            writeTv tv ty
                            stio ty
zonkTau other = stio other      -- TVar and TCon
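-- Illustrative trace of the path shortening in 'zonkTau'
-- (t1, t2 are hypothetical meta variables):
--      before:  t1 := t2,  t2 := Int
--      zonkTau (Meta t1)  returns Int
--      after:   t1 := Int, t2 := Int
-- so later lookups of t1 need a single hop only.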

zonkKind ∷ Kind → StG Kind
zonkKind (KGen t) = KGen <$> mapM zonkTau t
zonkKind other    = pure other


substRigidSigma [] sigma = sigma
substRigidSigma bound (ForAll b rho) = ForAll b
        (substRigidRho (filter (`notElem` map _.var b) bound) rho)
        
substRigidRho [] rho = rho
substRigidRho bound (RhoFun ctxs sig rho) = RhoFun 
        (map (substRigidCtx bound) ctxs)
        (substRigidSigma bound sig)
        (substRigidRho   bound rho)
        
substRigidRho bound (RhoTau ctxs tau) = 
        RhoTau (map (substRigidCtx bound) ctxs)  (substRigidTau bound tau)

substRigidCtx :: [String] -> Context -> Context
substRigidCtx bound ctx = ctx.{tau <- substRigidTau bound}

substRigidTau bound (TApp a b) = TApp 
        (substRigidTau bound a)
        (substRigidTau bound b)
        
substRigidTau bound (meta@Meta (Rigid {hint, kind}))           -- this is what happens in the end
    | hint `elem` bound = (TVar Position.null kind hint)
substRigidTau bound tau = tau



quantified = quantifiedExcept 0
{-
 * quantify a bunch of rho types
 * do not take a certain symbol into account
 -}
quantifiedExcept :: Int -> [Rho] -> StG [Sigma]
quantifiedExcept exc rhos = do
        g <- getST
        let rhosTvs = map (rhoTvs g) rhos
            eTvs    = envTvs g exc     -- tvs of all environment symbols except the one with sid exc
        let
            -- all unbound tv used in the Rhos except those in the environment
            allTvs = unique [ tv | tvs <- rhosTvs,
                                        tv <- tvs,
                                        MetaTv.isFlexi tv,
                                        isNothing (Meta tv).wildTau,
                                        tv `notElem` eTvs ]
            -- select a type variable name for each tv
            newvars = filter (`notElem` used) (allBinders g)
            bound   = zip newvars allTvs
            -- make sigma for rho with the tvs that appear in that rho
            mksig ∷ [(String,MetaTv)] → (Rho,[MetaTv]) → StG Sigma
            mksig bound (rho,tvs) = do
                    nvz ← mapM (\tv → zonkKind tv.kind >>= pure . tv.{kind=}) nv
                    rhoz ← zonkRho rho
                    pure (ForAll nvz rhoz) 
                where nv = [ TVar{pos, kind=MetaTv.kind v, var=n} | (n,v) <- bound, v `elem` tvs]
                      pos = getpos rho
        foreach bound bind                              -- actually write TVars in the MetaTvs
        mapSt (mksig bound) (zip rhos rhosTvs)          -- create (and return) the sigmas
    where
        pos = Position.null -- Position.merges (map U.rhoPos rhos)
        -- TVar names used in the Rhos
        used = [ u | r <- rhos, u <- tyVarBndrs r ]
        bind ∷ (String,MetaTv) → StG ()
        bind (var,tv) = case tv.kind of
            KGen t → do
                        writeTv tv (TVar {pos, var, kind=KVar})
                        t' ← mapM zonkTau t
                        writeTv tv (TVar {pos, var, kind=KGen t'})
            other  → writeTv tv (TVar {pos, var, kind=tv.kind})
            

quantify rho = do
    sigs <- quantified [rho]
    stio (head sigs)


canonicSignature sig = (instantiate sig >>= zonkRho) >>= quantify


{--
 * get all the binders used in ForAlls in the type so that when
 * quantifying an outer forall we can avoid these inner ones
 -}
tyVarBndrs ty = (uniq • sort) (bndrs ty) where
    bndrs (RhoFun _ (ForAll tvs arg) res)
                = (map _.var tvs ++ bndrs arg) ++ bndrs res
    bndrs _     = []
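-- Illustrative example (the type is hypothetical):
--      tyVarBndrs ((forall b. [b]) -> (forall a. Maybe a) -> Int)
--   collects "b" from the first argument and "a" from the second one,
--   then sorts and removes duplicates:
--      ["a", "b"]
-- 'quantifiedExcept' skips these names when it picks new binders.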

{--
 * get the open 'Context's from a canonic 'Rho' in a typechecked 'Expr'
 *
 * These are the contexts that have not been checked in instanceOf
 -}
exContext :: Global -> Expr -> [Context] 
exContext g ex = case Expr.typ ex of
    Just (ForAll _ rho) -> reducedCtxs g rho.context     -- rho is canonical
    Nothing -> Prelude.error ("exContext: typ=Nothing in " ++ ex.nice g)

{--
 * enrich the type by all contexts found in any subexpr
 -}
contexts ex typ = do
        g <- getST
        let pos = getpos ex
            rho = canonicContext g typ
        case ex of
            Vbl {pos} -> simplify pos rho
            Con {pos} -> simplify pos rho
            Lit {pos} -> simplify pos rho
            Ann ex ty -> do
                let ectx = exContext g ex
                simplify pos rho.{context <- mergeCtx ectx}
            App fun arg _ -> do
                let fctx = exContext g fun
                let actx = exContext g arg
                simplify pos rho.{context <- mergeCtx (mergeCtx fctx actx)}
            Let {env,ex} -> do
                let ectx = exContext g ex
                syms <- mapSt U.findV env
                subexs <- sequence [ ex | SymV {expr = Just ex} <- syms ]
                let rctxss = map (exContext g) subexs
                let rctxs = [ ctx | ctxs <- rctxss, ctx <- ctxs ]
                -- take only contexts that have at least 1 flexi tv
                    rtvss = map (ctxTvs g) rctxs
                let ctxs =  [ ctx | (ctx,tvs) <- zip rctxs rtvss, any (MetaTv.isFlexi) tvs]
                let merged = fold mergeCtx rho.context [ectx,ctxs]
                simplify pos rho.{context=merged}
            Lam {ex} -> do
                let ectx = exContext g ex
                E.logmsg TRACET (getpos ex) (text ("contexts: lamrho="
                    ++ nicectx rho.context g
                    ++ ", ectx=" ++ nicectx ectx g))
                simplify pos rho.{context <- mergeCtx ectx}
            Ifte c t e _ -> do
                let ctxs   = map (exContext g) [c,t,e]
                let merged = fold mergeCtx rho.context ctxs
                simplify pos rho.{context=merged}
            Case {ex,alts} -> do
                let ectx = exContext g ex
                    ctxs = map (exContext g • CAlt.ex) alts
                let merged = fold mergeCtx rho.context (ectx:ctxs)
                simplify pos rho.{context=merged}
            Mem  {ex} -> do         -- can happen when x.xyz does not typecheck
                let ectx = exContext g ex
                simplify pos rho.{context <- mergeCtx ectx}
            inv -> do
                g <- getST
                E.fatal (getpos inv) (text ("contexts: Invalid expression " ++ inv.nice g))

canonicContext :: Global -> Rho -> Rho
canonicContext g (RhoTau ctxs tau) = 
        let rctxs = reducedCtxs g ctxs
        in (RhoTau rctxs tau)
canonicContext g (RhoFun ctxs (ForAll bs rhoA) rhoB) = 
        let rctxs = reducedCtxs g ctxs
            rho1  = canonicContext g rhoA
            rho2  = canonicContext g rhoB
        in (RhoFun rctxs {-merged-} (ForAll bs rho1.{context=[]}) rho2.{context=[]})

{--
 * Reduce a 'Tau' to a form where only unbound 'Meta's occur.
 *
 * This is different from 'zonkTau' insofar as no meta type variables are changed.
 -}
reducedTau ∷ Global → Tau → Tau
reducedTau g (TApp a b)     = TApp (reducedTau g a) (reducedTau g b)
reducedTau g (meta@Meta{})  = case reduced meta g of
                                t@Meta{} → t                    -- unbound
                                tau      → reducedTau g tau     -- could be   <X> -> <Y>
reducedTau g (t@TVar{})     = t
reducedTau g (t@TCon{})     = t
reducedTau g (t@TSig{})     = t

{--
 * reduce a list of 'Context's, so that only unbound 'Meta' remain
 -}
reducedCtxs g ctxs = map (reducedCtx g) ctxs

{-- reduce a 'Context' so that only unbound 'Meta' variables remain -}
reducedCtx :: Global -> Context -> Context
reducedCtx g ctx = ctx.{tau <- reducedTau g}


{--
 * merge two already reduced 'Context's
 -}
mergeCtx [] ctx = ctx
mergeCtx (c:cs) ctx
    -- Context.checked c = filter (not • sameCtx c) (mergeCtx cs ctx)
    | elemBy sameCtx c ctx = mergeCtx cs ctx
    | otherwise = c : mergeCtx cs ctx
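-- Illustrative example (the contexts are hypothetical):
--      mergeCtx [Eq a, Ord b] [Ord b, Show c]
--   keeps  Eq a  (not in the second list), drops  Ord b  (already there
--   according to 'sameCtx') and then appends the second list unchanged:
--      [Eq a, Ord b, Show c]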

sameCtx :: Context -> Context -> Bool
sameCtx ca cb = ca.{cname?} && cb.{cname?} && ca.cname == cb.cname && sameTau ca.tau cb.tau

--- Check identity of two 'Tau's. This works only on 'Tau's reduced with 'reducedTau'.
sameTau :: Tau -> Tau -> Bool
sameTau (Meta a) (Meta b) = a == b
sameTau (TVar {var=a}) (TVar {var=b})   = a == b
sameTau (TCon {name=a}) (TCon {name=b}) = a == b
sameTau (TApp a b) (TApp c d) = sameTau a c && sameTau b d
-- sameTau (TFun a b) (TFun c d) = sameTau a c && sameTau b d
sameTau _ _ = false

--- If /C/ is a super class of /D/, then /D tau/ implies /C tau/ for the same tau.
--- Example: 'Ord' a implies 'Eq' a.
--- The first context argument is the implying one, the second the implied one.
impliesG g (Ctx _ d t1) (Ctx _ c t2) = isSuper c g d && sameTau t1 t2

{--
 * Simplify the 'Context's of a 'Rho'.
 * - if a context is of the form @C a@ or @C (a b ...)@, check that it is not implied
 *   by other contexts, e.g. (Ord a, Num a) is simplified to Num a
 * - if a context is of the form @C (T ...)@, make sure that the instance exists and add
 *   its implications, e.g. Eq [a] will add Eq a.
 -}
simplify :: Position -> Rho -> StG Rho
simplify pos rho = do
    g <- getST
    E.logmsg TRACET pos (text ("simplify " ++ rho.nice g))
    let
        implies = impliesG g
        single, singler :: Context -> String
        single  ctx = nicectx  [ctx] g
        singler ctx = nicerctx [ctx] g
        context = reducedCtxs g rho.context
    case context of
        [] -> stio rho.{context}
        (ctx:ctxs) -> case ctx.tau.flat of
            [] -> Prelude.error "Tau.flat returns empty list"       -- avoid case warning
            t1:ts | isVarMeta t1 = if (any (`implies` ctx) ctxs2)
                       then do
                            E.logmsg TRACET pos (text ("dropped: " ++ single ctx ++ "  (implied)"))
                            simplify pos rho.{context=ctxs2}            -- drop ctx as it is implied
                       else do
                            E.logmsg TRACET pos (text ("retained: " ++ single ctx))
                            rho <- simplify pos rho.{context=ctxs2}
                            stio rho.{context <- (ctx:)}
                  | otherwise = do
                        implications <- instanceOf ctx.pos ctx.cname ctx.tau
                        let reducedctxs = reducedCtxs g (ctx:implications)
                        let !ctx          = head reducedctxs
                            !implications = tail reducedctxs
                        E.logmsg TRACET pos (text ("implications of " ++ single ctx
                            ++ " are " ++ nicectx implications g))
                        when (not (null implications)) do
                            E.explain pos (text ("the implications of " ++ singler ctx
                                ++ "  are  " ++ joined ", " (map singler implications)))
                        rho <- simplify pos rho.{context = ctxs2 ++ implications}
                        -- tau <- reducedTau ctx.tau
                        stio rho -- .{context <- (ctx.{checked=true, tau}:)}
                  where ctxs2 = filter (not • (ctx `implies`)) ctxs

--- tell if this is either a 'TVar' or a 'Meta'
isVarMeta (TVar {var}) = true
isVarMeta (Meta _)     = true
isVarMeta _            = false


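{--
 * Check that the type constructor at the head of 'Tau' has an instance
 * for the class named by the 'QName'.
 *
 * Returns the contexts of the instantiated instance type (the implications),
 * with their position set to the given one. If there is no such instance,
 * an error is flagged and the empty list is returned.
 -}
instanceOf :: Position -> QName -> Tau -> StG [Context]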
instanceOf pos qn tau = do
    g <- getST
    E.logmsg TRACET pos (text ("is " ++ nice tau g ++ " instance of " ++ nice qn g ++ "?"))
    E.explain pos (text ("type  " ++ nicer tau g ++ "  must be instance of  " ++ nice qn g))
    let tcon =  head tau.flat
        showtn (TName pack base) = pack.raw ++ "." ++ base
        showtn _ = error "showtn: must be type name"
    case tcon of
        TCon {name} -> do
            E.logmsg TRACET pos (text ("tcon is " ++ showtn name))
            clas <- findC qn
            E.logmsg TRACET pos (text ("class " ++ showtn clas.name ++ " has instances for "
                                          ++ joined ", " (map (showtn • fst) clas.insts)))
            case filter ((name ==) • fst) clas.insts of
                [] -> do
                    E.error pos (msgdoc (nicer tau g ++ " is not an instance of " ++ nice qn g))
                    stio []
                (_,iname):_ -> do
                    inst <- findI iname
                    E.logmsg TRACET pos (text ("found instance " ++ nicer inst.typ g))
                    E.explain pos (text ("there is an instance for " ++ nicer inst.typ g))
                    rho <- instantiate inst.typ
                    -- e.g. Eq 42 => [42], where 42 stands for a fresh meta variable
                    E.explain pos (text ("we assume there is a variable inst::" ++ nicer tau g
                        ++ "  and check if it unifies with " ++ rho.nicer g))
                    let inst = Local 0 "inst"
                    subsCheckRR (Vbl pos inst Nothing) (RhoTau [] tau) rho
                    stio (map _.{pos} rho.context)
        _ -> do
            E.error pos (msgdoc (nicer tau g ++ " is not, and cannot be, an instance of " ++ nice qn g))
            stio []

